#!/home/paddleocr/venv/bin/python

import logging
import os
import asyncio
import tempfile
import time

from aiohttp import web

logger = logging.getLogger(__name__)

from paddleocr import PaddleOCR

LANGUAGES = {
  'abq': ('阿巴扎文', 'abq'),
  'ady': ('阿迪格文', 'ady'),
  'af': ('南非荷兰文', 'af'),
  'anp': ('昂加文', 'ang'),
  'ar': ('阿拉伯文', 'ar'),
  'ar-SA': ('沙特阿拉伯文', 'sa'),
  'av': ('阿瓦尔文', 'ava'),
  'az': ('阿塞拜疆文', 'az'),
  'be': ('白俄罗斯文', 'be'),
  'bg': ('保加利亚文', 'bg'),
  'bh': ('比哈尔文', 'bh'),
  'bn': ('孟加拉文', 'bho'),
  'bs': ('波斯尼亚文', 'bs'),
  'cs': ('捷克文', 'cs'),
  'cy': ('威尔士文', 'cy'),
  'da': ('丹麦文', 'da'),
  'dar': ('达尔格瓦文', 'dar'),
  'de': ('德文', 'german'),
  'en': ('英文', 'en'),
  'es': ('西班牙文', 'es'),
  'et': ('爱沙尼亚文', 'et'),
  'fa': ('波斯文', 'fa'),
  'fr': ('法文', 'fr'),
  'ga': ('爱尔兰文', 'ga'),
  'gom': ('果阿邦孔卡尼文', 'gom'),
  'hi': ('印地文', 'hi'),
  'hr': ('克罗地亚文', 'hr'),
  'hu': ('匈牙利文', 'hu'),
  'id': ('印尼文', 'id'),
  'inh': ('印古什文', 'inh'),
  'is': ('冰岛文', 'is'),
  'it': ('意大利文', 'it'),
  'ja': ('日文', 'japan'),
  'ko': ('韩文', 'korean'),
  'ku': ('库尔德文', 'ku'),
  'lbe': ('拉克文', 'lbe'),
  'lez': ('列兹金文', 'lez'),
  'lt': ('立陶宛文', 'lt'),
  'lv': ('拉脱维亚文', 'lv'),
  'mag': ('摩揭陀文', 'mah'),
  'mai': ('迈蒂利文', 'mai'),
  'mi': ('毛利文', 'mi'),
  'mn': ('蒙古文', 'mn'),
  'mr': ('马拉地文', 'mr'),
  'ms': ('马来文', 'ms'),
  'mt': ('马耳他文', 'mt'),
  'ne': ('尼泊尔文', 'ne'),
  'new': ('尼瓦尔文', 'new'),
  'nl': ('荷兰文', 'nl'),
  'no': ('挪威文', 'no'),
  'oc': ('奥克文', 'oc'),
  'pl': ('波兰文', 'pl'),
  'pt': ('葡萄牙文', 'pt'),
  'ro': ('罗马尼亚文', 'ro'),
  'ru': ('俄文', 'ru'),
  'sck': ('萨德里文', 'sck'),
  'sk': ('斯洛伐克文', 'sk'),
  'sl': ('斯洛文尼亚文', 'sl'),
  'sq': ('阿尔巴尼亚文', 'sq'),
  'sr-Cyrl': ('塞尔维亚文（西里尔字母）', 'rs_cyrillic'),
  'sr-Latn': ('塞尔维亚文（拉丁字母）', 'rs_latin'),
  'sv': ('瑞典文', 'sv'),
  'sw': ('斯瓦希里文', 'sw'),
  'ta': ('泰米尔文', 'ta'),
  'tab': ('塔巴萨兰文', 'tab'),
  'te': ('泰卢固文', 'te'),
  'tl': ('他加禄文', 'tl'),
  'tr': ('土耳其文', 'tr'),
  'ug': ('维吾尔文', 'ug'),
  'uk': ('乌克兰文', 'uk'),
  'ur': ('乌尔都文', 'ur'),
  'uz': ('乌兹别克文', 'uz'),
  'vi': ('越南文', 'vi'),
  'zh-Hans': ('中文（简体）', 'ch'),
  'zh-Hant': ('中文（繁体）', 'chinese_cht'),
}

async def ocr(request):
  multipart = await request.multipart()

  with tempfile.TemporaryDirectory() as tmpdir:
    while True:
      part = await multipart.next()
      if not part:
        break

      if part.name == 'file':
        try:
          _, ext = part.filename.split('.')
        except ValueError:
          ext = 'png'
        tmpfile = os.path.join(tmpdir, f'image.{ext}')
        logger.info('image %s uploaded to %s', part.filename, tmpfile)
        with open(tmpfile, 'wb') as f:
          while True:
            data = await part.read_chunk()
            if not data:
              break
            f.write(part.decode(data))

      elif part.name == 'lang':
        lang = await part.text()

    st = time.time()
    res = await ocr_file(tmpfile, lang)
    logger.info('OCR time: %.3fs', time.time() - st)

  return web.json_response(
    {'result': res},
    headers = {
      'Access-Control-Allow-Origin': '*',
    },
  )

_ocr_caches = {}
_ocr_lock = asyncio.Lock()
async def ocr_file(img, lang):
  if (ocr := _ocr_caches.get(lang)) is None:
    ocr = PaddleOCR(
      lang = LANGUAGES[lang][1],
      use_angle_cls = True,
      show_log = False,
    )
    _ocr_caches[lang] = ocr
  try:
    async with _ocr_lock:
      loop = asyncio.get_running_loop()
      result = await loop.run_in_executor(None, ocr.ocr, img)
  except FileNotFoundError:
    return
  # change format back to 2.5; the first image?
  return result[0]

def setup_app(app):
  app.router.add_post('/api', ocr)

def main():
  import argparse

  from nicelogger import enable_pretty_logging

  parser = argparse.ArgumentParser(
    description = 'HTTP services for PaddleOCR',
  )
  parser.add_argument('--loglevel', default='info',
                      choices=['debug', 'info', 'warn', 'error'],
                      help='log level')
  args = parser.parse_args()

  enable_pretty_logging(args.loglevel.upper())

  app = web.Application()
  setup_app(app)

  web.run_app(app, path='/run/paddleocr/http.sock')

if __name__ == '__main__':
  main()
